{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7765UFHoyGx6"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KsOkK8O69PyT"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZS8z-_KeywY9"
      },
      "source": [
        "# TF Lattice 自定义 Estimator"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r61fkA2i9Y3_"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/lattice/tutorials/custom_estimators\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\">在 TensorFlow.org 上查看 </a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/lattice/tutorials/custom_estimators.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">在 Google Colab 中运行 </a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/lattice/tutorials/custom_estimators.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">在 GitHub 中查看源代码</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/lattice/tutorials/custom_estimators.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\"> 下载笔记本</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ur6yCw7YVvr8"
      },
      "source": [
        "## 概述\n",
        "\n",
        "您可以使用自定义 Estimator 通过 TFL 层创建任意单调模型。本指南概述了创建此类 Estimator 所需的步骤。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x769lI12IZXB"
      },
      "source": [
        "## 设置"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fbBVAR6UeRN5"
      },
      "source": [
        "安装 TF Lattice 软件包："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bpXjJKpSd3j4"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install tensorflow-lattice"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jSVl9SHTeSGX"
      },
      "source": [
        "导入所需的软件包："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "P9rMpg1-ASY3"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import logging\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sys\n",
        "import tensorflow_lattice as tfl\n",
        "from tensorflow import feature_column as fc\n",
        "\n",
        "from tensorflow_estimator.python.estimator.canned import optimizers\n",
        "from tensorflow_estimator.python.estimator.head import binary_class_head\n",
        "logging.disable(sys.maxsize)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "svPuM6QNxlrH"
      },
      "source": [
        "下载 UCI Statlog (Heart) 数据集："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "M0CmH1gPASZF"
      },
      "outputs": [],
      "source": [
        "csv_file = tf.keras.utils.get_file(\n",
        "    'heart.csv', 'http://storage.googleapis.com/applied-dl/heart.csv')\n",
        "df = pd.read_csv(csv_file)\n",
        "target = df.pop('target')\n",
        "train_size = int(len(df) * 0.8)\n",
        "train_x = df[:train_size]\n",
        "train_y = target[:train_size]\n",
        "test_x = df[train_size:]\n",
        "test_y = target[train_size:]\n",
        "df.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nKkAw12SxvGG"
      },
      "source": [
        "设置用于在本指南中进行训练的默认值："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "1T6GFI9F6mcG"
      },
      "outputs": [],
      "source": [
        "LEARNING_RATE = 0.1\n",
        "BATCH_SIZE = 128\n",
        "NUM_EPOCHS = 1000"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0TGfzhPHzpix"
      },
      "source": [
        "## 特征列\n",
        "\n",
        "与任何其他 TF Estimator 一样，数据通常需要通过 input_fn 传递给 Estimator，并使用 [FeatureColumns](https://tensorflow.google.cn/guide/feature_columns) 进行解析。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DCIUz8apzs0l"
      },
      "outputs": [],
      "source": [
        "# Feature columns.\n",
        "# - age\n",
        "# - sex\n",
        "# - ca        number of major vessels (0-3) colored by flourosopy\n",
        "# - thal      3 = normal; 6 = fixed defect; 7 = reversable defect\n",
        "feature_columns = [\n",
        "    fc.numeric_column('age', default_value=-1),\n",
        "    fc.categorical_column_with_vocabulary_list('sex', [0, 1]),\n",
        "    fc.numeric_column('ca'),\n",
        "    fc.categorical_column_with_vocabulary_list(\n",
        "        'thal', ['normal', 'fixed', 'reversible']),\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hEZstmtT2CA3"
      },
      "source": [
        "请注意，分类特征不需要用密集特征列包装，因为 `tfl.laysers.CategoricalCalibration` 层可以直接使用分类索引。"
      ]
    },
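    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of why no dense wrapper is needed, the sketch below mimics a categorical calibrator in plain NumPy: each integer category index selects one output value from a table. This is a simplified sketch under that assumption, not the implementation of `tfl.layers.CategoricalCalibration`; the helper `calibrate_categorical` and the values in `table` are hypothetical and chosen only for this example.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def calibrate_categorical(indices, table):\n",
        "  # Hypothetical helper: look up one output value per integer category index.\n",
        "  indices = np.asarray(indices, dtype=np.int64)\n",
        "  return np.asarray(table, dtype=np.float32)[indices]\n",
        "\n",
        "# Made-up per-category outputs for the three 'thal' categories.\n",
        "table = [0.0, 0.4, 1.0]\n",
        "calibrated = calibrate_categorical([0, 2, 1, 0], table)\n",
        "print(calibrated)\n",
        "```\n",
        "\n",
        "In the actual layer, the per-category outputs are learned during training rather than fixed by hand."
      ]
    },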
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H_LoW_9m5OFL"
      },
      "source": [
        "## 创建 input_fn\n",
        "\n",
        "与任何其他 Estimator 一样，您可以使用 input_fn 将数据馈送给模型进行训练和评估。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lFVy1Efy5NKD"
      },
      "outputs": [],
      "source": [
        "train_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=train_x,\n",
        "    y=train_y,\n",
        "    shuffle=True,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=NUM_EPOCHS,\n",
        "    num_threads=1)\n",
        "\n",
        "test_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=test_x,\n",
        "    y=test_y,\n",
        "    shuffle=False,\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=1,\n",
        "    num_threads=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kbrgSr9KaRg0"
      },
      "source": [
        "## 创建 model_fn\n",
        "\n",
        "您可以通过多种方式创建自定义 Estimator。在这里，我们将构造一个在已解析的输入张量上调用 Keras 模型的 `model_fn`。要解析输入特征，您可以使用 `tf.feature_column.input_layer`、`tf.keras.layers.DenseFeatures` 或 `tfl.estimators.transform_features`。如果使用后者，则不需要使用密集特征列包装分类特征，并且生成的张量不会串联，这样可以更轻松地在校准层中使用特征。\n",
        "\n",
        "要构造模型，您可以搭配使用 TFL 层或任何其他 Keras 层。在这里，我们从 TFL 层创建一个校准点阵 Keras 模型，并施加一些单调性约束。随后，我们使用 Keras 模型创建自定义 Estimator。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n2Zrv6OPaQO2"
      },
      "outputs": [],
      "source": [
        "def model_fn(features, labels, mode, config):\n",
        "  \"\"\"model_fn for the custom estimator.\"\"\"\n",
        "  del config\n",
        "  input_tensors = tfl.estimators.transform_features(features, feature_columns)\n",
        "  inputs = {\n",
        "      key: tf.keras.layers.Input(shape=(1,), name=key) for key in input_tensors\n",
        "  }\n",
        "\n",
        "  lattice_sizes = [3, 2, 2, 2]\n",
        "  lattice_monotonicities = ['increasing', 'none', 'increasing', 'increasing']\n",
        "  lattice_input = tf.keras.layers.Concatenate(axis=1)([\n",
        "      tfl.layers.PWLCalibration(\n",
        "          input_keypoints=np.linspace(10, 100, num=8, dtype=np.float32),\n",
        "          # The output range of the calibrator should be the input range of\n",
        "          # the following lattice dimension.\n",
        "          output_min=0.0,\n",
        "          output_max=lattice_sizes[0] - 1.0,\n",
        "          monotonicity='increasing',\n",
        "      )(inputs['age']),\n",
        "      tfl.layers.CategoricalCalibration(\n",
        "          # Number of categories including any missing/default category.\n",
        "          num_buckets=2,\n",
        "          output_min=0.0,\n",
        "          output_max=lattice_sizes[1] - 1.0,\n",
        "      )(inputs['sex']),\n",
        "      tfl.layers.PWLCalibration(\n",
        "          input_keypoints=[0.0, 1.0, 2.0, 3.0],\n",
        "          output_min=0.0,\n",
        "          output_max=lattice_sizes[0] - 1.0,\n",
        "          # You can specify TFL regularizers as tuple\n",
        "          # ('regularizer name', l1, l2).\n",
        "          kernel_regularizer=('hessian', 0.0, 1e-4),\n",
        "          monotonicity='increasing',\n",
        "      )(inputs['ca']),\n",
        "      tfl.layers.CategoricalCalibration(\n",
        "          num_buckets=3,\n",
        "          output_min=0.0,\n",
        "          output_max=lattice_sizes[1] - 1.0,\n",
        "          # Categorical monotonicity can be partial order.\n",
        "          # (i, j) indicates that we must have output(i) &lt;= output(j).\n",
        "          # Make sure to set the lattice monotonicity to 'increasing' for this\n",
        "          # dimension.\n",
        "          monotonicities=[(0, 1), (0, 2)],\n",
        "      )(inputs['thal']),\n",
        "  ])\n",
        "  output = tfl.layers.Lattice(\n",
        "      lattice_sizes=lattice_sizes, monotonicities=lattice_monotonicities)(\n",
        "          lattice_input)\n",
        "\n",
        "  training = (mode == tf.estimator.ModeKeys.TRAIN)\n",
        "  model = tf.keras.Model(inputs=inputs, outputs=output)\n",
        "  logits = model(input_tensors, training=training)\n",
        "\n",
        "  if training:\n",
        "    optimizer = optimizers.get_optimizer_instance_v2('Adagrad', LEARNING_RATE)\n",
        "  else:\n",
        "    optimizer = None\n",
        "\n",
        "  head = binary_class_head.BinaryClassHead()\n",
        "  return head.create_estimator_spec(\n",
        "      features=features,\n",
        "      mode=mode,\n",
        "      labels=labels,\n",
        "      optimizer=optimizer,\n",
        "      logits=logits,\n",
        "      trainable_variables=model.trainable_variables,\n",
        "      update_ops=model.updates)"
      ]
    },
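    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The calibrators in `model_fn` map each raw feature into the input range of its lattice dimension. As a rough illustration of that idea only, the NumPy sketch below evaluates a monotonically increasing piecewise-linear function whose outputs span `[0, lattice_size - 1]`. The helper `pwl_calibrate` and the keypoint outputs are hypothetical values chosen for this example; in the real layer the outputs are learned, and this is not how `tfl.layers.PWLCalibration` is implemented.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Input keypoints matching the 'age' calibrator in model_fn.\n",
        "input_keypoints = np.linspace(10, 100, num=8, dtype=np.float32)\n",
        "lattice_size = 3\n",
        "# Hypothetical keypoint outputs: non-decreasing, spanning [0, lattice_size - 1].\n",
        "output_keypoints = np.linspace(0.0, lattice_size - 1.0, num=8, dtype=np.float32)\n",
        "\n",
        "def pwl_calibrate(x):\n",
        "  # Linear interpolation between keypoints; np.interp clamps inputs outside\n",
        "  # the keypoint range to the first or last output value.\n",
        "  return np.interp(x, input_keypoints, output_keypoints)\n",
        "\n",
        "ages = np.array([5.0, 10.0, 55.0, 100.0, 120.0])\n",
        "calibrated = pwl_calibrate(ages)\n",
        "print(calibrated)\n",
        "```\n",
        "\n",
        "Because the keypoint outputs never decrease, the calibrated value never decreases as age grows, which is the property the `monotonicity='increasing'` setting asks the trained layer to preserve."
      ]
    },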
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mng-VtsSbVtQ"
      },
      "source": [
        "## 训练和 Estimator\n",
        "\n",
        "使用 `model_fn`，我们可以创建和训练 Estimator。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j38GaEbKbZju"
      },
      "outputs": [],
      "source": [
        "estimator = tf.estimator.Estimator(model_fn=model_fn)\n",
        "estimator.train(input_fn=train_input_fn)\n",
        "results = estimator.evaluate(input_fn=test_input_fn)\n",
        "print('AUC: {}'.format(results['auc']))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "custom_estimators.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
